# ================================================
# 0) Install & load required packages
# ================================================
# install.packages(c("haven", "broom", "dplyr", "ggplot2"))
library(haven)    # read_dta()
library(broom)    # tidy(), conf.int
library(dplyr)    # data-wrangling
library(ggplot2)  # plotting

# ================================================
# 1) Load data and add a country indicator
# ================================================
# Read the merged Stata file, keep respondents that carry a Swedish or a
# Danish id, and tag every remaining row with the country of that id.
df <- read_dta("merged_file_nomissings.dta")
df <- df %>%
  # Drop rows where both ids are missing (i.e. keep rows with either id).
  filter(!(is.na(id_se) & is.na(id_dk))) %>%
  mutate(
    # id_se is tested first, so a row with both ids set is tagged "Sweden".
    country = case_when(
      !is.na(id_se) ~ "Sweden",
      !is.na(id_dk) ~ "Denmark"
    )
  )

# ================================================
# 2) Define outcomes and subgroup variables + labels
# ================================================
# Outcome variables and the names they are displayed under.
outcome_labels <- c(
  marry_n   = "Social Distancing",
  traits_n  = "Trait Stereotyping",
  dislike_n = "Party Dislike"
)
outcomes <- names(outcome_labels)

# Turn the ideology codes 1/2/3 into a three-level factor. Because `right`
# is a factor from here on, the models below estimate one interaction term
# per non-reference level instead of a single slope.
df[["right"]] <- factor(
  df[["right"]],
  levels = c(1, 2, 3),
  labels = c("Left", "Center", "Right")
)

# Subgroup (moderator) variables and their display names.
sub_labels <- c(
  right    = "Ideology",
  gender   = "Gender",
  HH_inc_d = "Income",
  educ_d   = "Education"
)
subgroups <- names(sub_labels)

# ================================================
# 3) Fit interaction models and extract effects (FLAG)
# ================================================
# For every country x outcome x subgroup, fit the linear model
#   <outcome> ~ treat_flag * <subgroup>
# on that country's rows and keep the treat_flag-by-subgroup interaction
# row(s) of the tidied coefficient table (estimate, confidence bounds,
# p-value). Each result is stored in `results` under "<country>_<outcome>_<sg>".
results <- list()
for (ctry in c("Sweden", "Denmark")) {
  # The country subset depends only on `ctry`, so build it once per country
  # instead of once per fitted model.
  dat <- filter(df, country == ctry)

  for (y in outcomes) {
    for (sg in subgroups) {
      mod <- lm(as.formula(paste0(y, " ~ treat_flag * ", sg)), data = dat)
      coefs <- tidy(mod, conf.int = TRUE)

      if (is.factor(dat[[sg]])) {
        # Factor subgroup (ideology): keep every interaction row; `level` is
        # whatever follows "treat_flag:" in the term name.
        coefs <- filter(coefs, grepl("^treat_flag:", term))
        lvl <- gsub("treat_flag:", "", coefs$term)
      } else {
        # Non-factor subgroup: keep the single row named "treat_flag:<sg>" and
        # use the variable name itself as the level.
        coefs <- filter(coefs, term == paste0("treat_flag:", sg))
        lvl <- sg
      }

      tt <- coefs %>%
        mutate(country = ctry, outcome = y, subgroup = sg, level = lvl) %>%
        select(
          country, outcome, subgroup, level,
          estimate, conf.low, conf.high, p.value
        )

      # An empty table means the expected interaction term is not in the
      # model; report it instead of silently dropping the combination.
      if (nrow(tt) == 0) {
        warning(
          "No treat_flag interaction term found for ",
          ctry, " / ", y, " / ", sg,
          call. = FALSE
        )
      }

      results[[paste(ctry, y, sg, sep = "_")]] <- tt
    }
  }
}

# --- Stack the per-model results and build clean display labels ---
diff_df <- results %>%
  bind_rows() %>%
  mutate(
    # Subgroup display name; "Ideology" is expanded to its long form.
    subgroup_label = recode(
      factor(sub_labels[subgroup], levels = unname(sub_labels)),
      "Ideology" = "Ideology (Left–Right)"
    ),
    country     = factor(country, levels = c("Sweden", "Denmark")),
    outcome     = factor(outcome, levels = outcomes),
    significant = p.value < 0.05,
    # Row label for the y-axis, stored as a factor so the axis order is fixed
    # (ideology levels first, then the remaining subgroups).
    level_clean = factor(
      case_when(
        # With ignore.case = TRUE the "Right" pattern also matches the
        # lower-case variable name inside terms such as "rightCenter", so the
        # "Left" and "Center" tests must stay ahead of the "Right" test.
        subgroup == "right" & grepl("Left", level, ignore.case = TRUE)   ~ "Left",
        subgroup == "right" & grepl("Center", level, ignore.case = TRUE) ~ "Center",
        subgroup == "right" & grepl("Right", level, ignore.case = TRUE)  ~ "Right",
        subgroup == "gender"   ~ "Gender",
        subgroup == "educ_d"   ~ "Education",
        subgroup == "HH_inc_d" ~ "HH income",
        TRUE ~ level
      ),
      levels = c(
        "Left", "Center", "Right",
        "HH income", "Gender", "Education"
      )
    ),
    outcome_label = recode_factor(outcome, !!!outcome_labels)
  )

# --- Plot the estimates with clean row labels (no term prefix) ---
# One row of panels per outcome, one column per country.
# NOTE(review): `estimate` holds the treat_flag interaction coefficients
# collected in `results`; confirm the x-axis title describes them correctly.
ggplot(data = diff_df, mapping = aes(x = estimate, y = level_clean)) +
  facet_grid(outcome_label ~ country, scales = "free_y", space = "free_y") +
  labs(x = "Conditional Marginal Means (Flag)", y = NULL) +
  # Zero reference line is drawn first so intervals and points sit on top.
  geom_vline(xintercept = 0, linetype = "dashed") +
  geom_errorbarh(aes(xmin = conf.low, xmax = conf.high), height = 0) +
  geom_point(size = 3) +
  theme_minimal() +
  theme(
    panel.spacing.y  = unit(0.8, "lines"),
    panel.grid.minor = element_blank(),
    axis.text.y      = element_text(size = 11),
    strip.text       = element_text(face = "bold")
  )


###################################################################
# CAKE
###################################################################

# Same procedure as for the flag treatment, now with treat_cake: for every
# country x outcome x subgroup, fit
#   <outcome> ~ treat_cake * <subgroup>
# on that country's rows and keep the treat_cake-by-subgroup interaction
# row(s) of the tidied coefficient table. `results` is rebuilt from scratch.
results <- list()
for (ctry in c("Sweden", "Denmark")) {
  # The country subset depends only on `ctry`, so build it once per country
  # instead of once per fitted model.
  dat <- filter(df, country == ctry)

  for (y in outcomes) {
    for (sg in subgroups) {
      mod <- lm(as.formula(paste0(y, " ~ treat_cake * ", sg)), data = dat)
      coefs <- tidy(mod, conf.int = TRUE)

      if (is.factor(dat[[sg]])) {
        # Factor subgroup (ideology): keep every interaction row; `level` is
        # whatever follows "treat_cake:" in the term name.
        coefs <- filter(coefs, grepl("^treat_cake:", term))
        lvl <- gsub("treat_cake:", "", coefs$term)
      } else {
        # Non-factor subgroup: keep the single row named "treat_cake:<sg>" and
        # use the variable name itself as the level.
        coefs <- filter(coefs, term == paste0("treat_cake:", sg))
        lvl <- sg
      }

      tt <- coefs %>%
        mutate(country = ctry, outcome = y, subgroup = sg, level = lvl) %>%
        select(
          country, outcome, subgroup, level,
          estimate, conf.low, conf.high, p.value
        )

      # An empty table means the expected interaction term is not in the
      # model; report it instead of silently dropping the combination.
      if (nrow(tt) == 0) {
        warning(
          "No treat_cake interaction term found for ",
          ctry, " / ", y, " / ", sg,
          call. = FALSE
        )
      }

      results[[paste(ctry, y, sg, sep = "_")]] <- tt
    }
  }
}

# --- Clean up labels and fix the y-axis order (cake results) ---
diff_df <- results %>%
  bind_rows() %>%
  mutate(
    # Subgroup display name; "Ideology" is expanded to its long form.
    subgroup_label = recode(
      factor(sub_labels[subgroup], levels = unname(sub_labels)),
      "Ideology" = "Ideology (Left–Right)"
    ),
    country     = factor(country, levels = c("Sweden", "Denmark")),
    outcome     = factor(outcome, levels = outcomes),
    significant = p.value < 0.05,
    # Row label for the y-axis, stored as a factor so the axis order is fixed.
    level_clean = factor(
      case_when(
        # With ignore.case = TRUE the "Right" pattern also matches the
        # lower-case variable name inside terms such as "rightCenter", so the
        # "Left" and "Center" tests must stay ahead of the "Right" test.
        subgroup == "right" & grepl("Left", level, ignore.case = TRUE)   ~ "Left",
        subgroup == "right" & grepl("Center", level, ignore.case = TRUE) ~ "Center",
        subgroup == "right" & grepl("Right", level, ignore.case = TRUE)  ~ "Right",
        subgroup == "gender"   ~ "Gender",
        subgroup == "educ_d"   ~ "Education",
        subgroup == "HH_inc_d" ~ "HH income",
        TRUE ~ level
      ),
      levels = c(
        "Left", "Center", "Right",
        "HH income", "Gender", "Education"
      )
    ),
    outcome_label = recode_factor(outcome, !!!outcome_labels)
  )

# --- Plot the cake estimates with clean row labels ---
# One row of panels per outcome, one column per country.
# NOTE(review): `estimate` holds the treat_cake interaction coefficients
# collected in `results`; confirm the x-axis title describes them correctly.
ggplot(data = diff_df, mapping = aes(x = estimate, y = level_clean)) +
  facet_grid(outcome_label ~ country, scales = "free_y", space = "free_y") +
  labs(x = "Conditional Marginal Means (Cake)", y = NULL) +
  # Zero reference line is drawn first so intervals and points sit on top.
  geom_vline(xintercept = 0, linetype = "dashed") +
  geom_errorbarh(aes(xmin = conf.low, xmax = conf.high), height = 0) +
  geom_point(size = 3) +
  theme_minimal() +
  theme(
    panel.spacing.y  = unit(0.8, "lines"),
    panel.grid.minor = element_blank(),
    axis.text.y      = element_text(size = 11),
    strip.text       = element_text(face = "bold")
  )


########
#OTHER SUBGROUPS
########

# ================================================
# 0) Install & load required packages
# ================================================
# install.packages(c("haven", "broom", "dplyr", "ggplot2"))  # if not already installed
library(haven)    # read_dta()
library(broom)    # tidy(), conf.int
library(dplyr)    # data‐wrangling
library(ggplot2)  # plotting

# ================================================
# 1) Load data and add a country indicator
# ================================================
# Re-read the merged Stata file and rebuild the country tag. Reassigning `df`
# from the file also discards the factor version of `right` created earlier
# in this file, so `right` enters the models below as stored in the .dta.
df <- read_dta("merged_file_nomissings.dta")
df <- df %>%
  # Drop rows where both ids are missing (i.e. keep rows with either id).
  filter(!(is.na(id_se) & is.na(id_dk))) %>%
  mutate(
    # id_se is tested first, so a row with both ids set is tagged "Sweden".
    country = case_when(
      !is.na(id_se) ~ "Sweden",
      !is.na(id_dk) ~ "Denmark"
    )
  )

# ================================================
# 2) Define outcomes and subgroup variables + labels
# ================================================
# Outcome variables and the names they are displayed under.
outcome_labels <- c(
  marry_n   = "Social Distancing",
  traits_n  = "Trait Stereotyping",
  dislike_n = "Party Dislike"
)
outcomes <- names(outcome_labels)

# ================================================
# Define demographic subgroup variables
# ================================================
# Names are the model variables; values are the y-axis labels of the plots.
sub_labels <- c(
  right    = "Ideology (Left–Right)",
  gender   = "Gender (Men=1, Women=0)",
  HH_inc_d = "High Income",
  educ_d   = "High Education"
)
subgroups <- names(sub_labels)

# ================================================
# 3) Fit interaction models and extract 1vs0/1vs2/1vs3 effect (FLAG treatment)
# ================================================
# For every country x outcome x subgroup, fit the linear model
#   <outcome> ~ treat_flag * <subgroup>
# on that country's rows and keep the single coefficient row whose term is
# exactly "treat_flag:<subgroup>". Each one-row table is stored in `results`
# under "<country>_<outcome>_<subgroup>".
results <- list()
for (ctry in c("Sweden", "Denmark")) {
  # The country subset depends only on `ctry`, so build it once per country
  # instead of once per fitted model.
  dat <- filter(df, country == ctry)

  for (y in outcomes) {
    for (sg in subgroups) {
      trm <- paste0("treat_flag:", sg)
      mod <- lm(as.formula(paste0(y, " ~ treat_flag * ", sg)), data = dat)

      tt <- tidy(mod, conf.int = TRUE) %>%
        filter(term == trm) %>%
        transmute(
          country  = ctry,
          outcome  = y,
          subgroup = sg,
          estimate,
          conf.low,
          conf.high,
          p.value
        )

      # The exact-name match finds nothing when the subgroup enters the model
      # under different term names (e.g. as a factor); report that instead of
      # silently dropping the combination.
      if (nrow(tt) == 0) {
        warning(
          "No term '", trm, "' found for ", ctry, " / ", y,
          call. = FALSE
        )
      }

      results[[paste(ctry, y, sg, sep = "_")]] <- tt
    }
  }
}

# Stack the per-model rows and convert the grouping columns to factors with
# a fixed level order; `outcome_label` carries the outcome display names.
diff_df <- results %>%
  bind_rows() %>%
  mutate(
    subgroup_label = factor(sub_labels[subgroup], levels = unname(sub_labels)),
    country        = factor(country, levels = c("Sweden", "Denmark")),
    outcome        = factor(outcome, levels = outcomes),
    significant    = p.value < 0.05,
    outcome_label  = recode_factor(outcome, !!!outcome_labels)
  )

# ================================================
# 4) Combined plot: one column per country, one row per outcome
# ================================================

# Coefficient plot of the flag interaction estimates, one row per subgroup.
# NOTE(review): `estimate` holds the treat_flag interaction coefficients
# collected in `results`; confirm the x-axis title describes them correctly.
ggplot(data = diff_df, mapping = aes(x = estimate, y = subgroup_label)) +
  facet_grid(outcome_label ~ country, scales = "free_y") +
  labs(x = "Conditional Marginal Means (Flag)", y = NULL) +
  # Zero reference line is drawn first so intervals and points sit on top.
  geom_vline(xintercept = 0, linetype = "dashed") +
  geom_errorbarh(aes(xmin = conf.low, xmax = conf.high), height = 0) +
  geom_point(size = 3) +
  theme_minimal() +
  theme(
    axis.text.y      = element_text(size = 11),
    panel.grid.minor = element_blank(),
    strip.text       = element_text(face = "bold")
  )

###################################################################
# CAKE
###################################################################

# Same procedure as for the flag treatment, now with treat_cake: for every
# country x outcome x subgroup, fit
#   <outcome> ~ treat_cake * <subgroup>
# on that country's rows and keep the single coefficient row whose term is
# exactly "treat_cake:<subgroup>". `results` is rebuilt from scratch.
results <- list()
for (ctry in c("Sweden", "Denmark")) {
  # The country subset depends only on `ctry`, so build it once per country
  # instead of once per fitted model.
  dat <- filter(df, country == ctry)

  for (y in outcomes) {
    for (sg in subgroups) {
      trm <- paste0("treat_cake:", sg)
      mod <- lm(as.formula(paste0(y, " ~ treat_cake * ", sg)), data = dat)

      tt <- tidy(mod, conf.int = TRUE) %>%
        filter(term == trm) %>%
        transmute(
          country  = ctry,
          outcome  = y,
          subgroup = sg,
          estimate,
          conf.low,
          conf.high,
          p.value
        )

      # The exact-name match finds nothing when the subgroup enters the model
      # under different term names (e.g. as a factor); report that instead of
      # silently dropping the combination.
      if (nrow(tt) == 0) {
        warning(
          "No term '", trm, "' found for ", ctry, " / ", y,
          call. = FALSE
        )
      }

      results[[paste(ctry, y, sg, sep = "_")]] <- tt
    }
  }
}

# Stack the per-model cake rows and convert the grouping columns to factors
# with a fixed level order; `outcome_label` carries the outcome display names.
diff_df <- results %>%
  bind_rows() %>%
  mutate(
    subgroup_label = factor(sub_labels[subgroup], levels = unname(sub_labels)),
    country        = factor(country, levels = c("Sweden", "Denmark")),
    outcome        = factor(outcome, levels = outcomes),
    significant    = p.value < 0.05,
    outcome_label  = recode_factor(outcome, !!!outcome_labels)
  )

# Coefficient plot of the cake interaction estimates, one row per subgroup.
# NOTE(review): `estimate` holds the treat_cake interaction coefficients
# collected in `results`; confirm the x-axis title describes them correctly.
ggplot(data = diff_df, mapping = aes(x = estimate, y = subgroup_label)) +
  facet_grid(outcome_label ~ country, scales = "free_y") +
  labs(x = "Conditional Marginal Means (Cake)", y = NULL) +
  # Zero reference line is drawn first so intervals and points sit on top.
  geom_vline(xintercept = 0, linetype = "dashed") +
  geom_errorbarh(aes(xmin = conf.low, xmax = conf.high), height = 0) +
  geom_point(size = 3) +
  theme_minimal() +
  theme(
    axis.text.y      = element_text(size = 11),
    panel.grid.minor = element_blank(),
    strip.text       = element_text(face = "bold")
  )

                 